/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *       this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *       contributors may be used to endorse or promote products derived from
 *       this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
 * DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of convolution tiles mapped to the
    GEMM A matrix (output gradient tile) from memory.

    This iterator assumes TensorNHWC layout of tensors in Global Memory.

    The iterators in this file implement the backward data gradient (Dgrad)
    operator, with separate specializations for unity-stride and strided
    problems.
*/
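
/*
    Usage sketch (illustrative only): the Shape, ThreadMap, and tile sizes
    below are assumptions chosen for the example, not mandated by this
    header.

        using Element   = cutlass::half_t;
        using Shape     = cutlass::MatrixShape<64, 64>;
        using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
                cutlass::layout::PitchLinearShape<64, 64>, 128, 8>;

        using Iterator = cutlass::conv::threadblock::
                Conv2dDgradOutputGradientTileAccessIteratorOptimized<
                        Shape, Element, ThreadMap>;

    Inside the implicit GEMM mainloop, each thread uses valid() and get() to
    issue predicated loads, operator++() to step between the accesses of one
    tile, and advance() to move to the next gemm_k tile.
*/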

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/threadblock/conv2d_params.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Shape_, typename Element_, typename ThreadMap_,
          conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity,
          typename AccessType_ = cutlass::AlignedArray<
                  Element_, ThreadMap_::kElementsPerAccess> >
class Conv2dDgradOutputGradientTileAccessIteratorOptimized;
/////////////////////////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs
// special handling to skip MMAs (Dx = Dy * w) on invalid filter positions
/////////////////////////////////////////////////////////////////////////////////////////////////
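// Example (illustrative numbers): with stride_h = 2 and R = 3, the filter
// rows split by r % stride_h into the groups {0, 2} (start_r = 0) and {1}
// (start_r = 1). The predicate masks computed in the constructor below
// record which positions in a group map to in-bounds output coordinates,
// so MMAs on the remaining positions are skipped.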
template <typename Shape_, typename Element_, typename ThreadMap_,
          typename AccessType_>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized<
        Shape_, Element_, ThreadMap_, conv::StrideSupport::kStrided,
        AccessType_> {
public:
    //
    // Types
    //
    using Shape = Shape_;
    using Element = Element_;
    using Layout = layout::TensorNHWC;
    using ThreadMap = ThreadMap_;
    using AccessType = AccessType_;
    using TensorRef = cutlass::TensorRef<Element, Layout>;
    using TensorCoord = typename Layout::TensorCoord;
    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;
    static IteratorAlgorithm const kIteratorAlgorithm =
            conv::IteratorAlgorithm::kOptimized;
    static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
    static int const kConvDim = 2;
    using ConvProblemSize = typename conv::Conv2dProblemSize;

    static int const kAccessesPerVector =
            ThreadMap::kElementsPerAccess / AccessType::kElements;

    static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
                  "Vectors implied by the thread map must be divisible by the "
                  "access type.");

    using Mask = uint64_t;
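
    // Each mask packs one predicate bit per filter position: bit r of the
    // first mask records whether filter row r maps to a valid (n, p)
    // coordinate, and bit s of the second mask records the same for filter
    // column s and the q coordinate. valid() tests the bits for the current
    // (filter_r_, filter_s_).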

    //
    // Simplifying assertions
    //

    static_assert(ThreadMap::Iterations::kContiguous == 1,
                  "Require Iterations::kContiguous == 1");

    //
    // Parameters structure
    //

    using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams;

private:
    Params const& params_;
    Conv2dProblemSize const& problem_size_;
    LongIndex iteration_contiguous_;
    LongIndex iteration_strided_;
    LongIndex iteration_vector_;

    // One pointer per access
    char const* pointer_[ThreadMap::Iterations::kStrided];

    int filter_k_;
    int filter_r_;
    int filter_s_;
    int start_r_;
    int start_s_;
    int64_t reset_bytes_s_;
    int64_t reset_bytes_r_;

    Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];

public:
    CUTLASS_HOST_DEVICE
    Conv2dDgradOutputGradientTileAccessIteratorOptimized(
            Params const& params, Conv2dProblemSize const& problem_size,
            Element const* ptr, int thread_idx,
            FastDivmod const& stride_h_divmod,
            FastDivmod const& stride_w_divmod, int start_r, int start_s,
            MatrixCoord const& threadblock_offset =
                    MatrixCoord()  // threadblock offset - units are whole CTA
                                   // tiles
            )
            : params_(params),
              problem_size_(problem_size),
              filter_k_(0),
              filter_r_(start_r),
              filter_s_(start_s),
              start_r_(start_r),
              start_s_(start_s) {
        layout::PitchLinearCoord thread_coord =
                ThreadMap::initial_offset(thread_idx);

        filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

        reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) *
                         params_.inc_next[0];

        reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) *
                                 params_.inc_next[0] +
                         (problem_size_.num_gemm_k_filter_r(start_r_) - 1) *
                                 params_.inc_next[1];

        int offset_n[ThreadMap::Iterations::kStrided];
        int offset_p[ThreadMap::Iterations::kStrided];
        int offset_q[ThreadMap::Iterations::kStrided];

        int filter_r = filter_r_;
        int filter_s = filter_s_;

        if (problem_size_.mode == Mode::kConvolution) {
            filter_r = (problem_size_.R - 1 - filter_r);
            filter_s = (problem_size_.S - 1 - filter_s);
        }

        // Starting h, w positions for filter position in gemm_k=0
        int start_h, start_w;
        strided_dgrad_starting_coords(problem_size_, stride_h_divmod,
                                      stride_w_divmod, filter_r, filter_s,
                                      start_h, start_w);

        // Effective starting P and Q for this filter position, required to
        // remap NHW rows
        int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) /
                problem_size_.stride_h;
        int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) /
                problem_size_.stride_w;
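        // For example (illustrative numbers): with H = 8, start_h = 1, and
        // stride_h = 2, this filter position reaches the h rows {1, 3, 5, 7},
        // so the effective P = (8 - 1 + 2 - 1) / 2 = 4.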

        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            pointer_[s] = reinterpret_cast<char const*>(ptr);

            int offset_npq =
                    (threadblock_offset.row() + thread_coord.strided() +
                     s * ThreadMap::Delta::kStrided) %
                    params_.tiled_rows_per_filter;

            // (STEP 1) Reorder NHW rows to start with the same filter
            // position
            offset_n[s] = offset_npq / (P * Q);
            int residual = offset_npq % (P * Q);

            int p = (residual / Q);
            int q = (residual % Q);

            int mapped_h = (start_h + p * problem_size_.stride_h);
            int mapped_w = (start_w + q * problem_size_.stride_w);

            // Access (p, q) coordinates in the Dy tensor for the filter
            // position in gemm_k=0. Note that (h + pad_h - filter_r) and
            // (w + pad_w - filter_s) are guaranteed to be divisible by
            // stride_h and stride_w.
            offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) /
                          problem_size_.stride_h;
            offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) /
                          problem_size_.stride_w;

            // Initialize pointers for gemm_k=0
            TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_};

            pointer_[s] +=
                    params_.layout(coord) * sizeof_bits<Element>::value / 8;
        }

        //
        // Precompute mask predicates
        //
        clear_mask();

        CUTLASS_PRAGMA_NO_UNROLL
        for (int r = start_r; r < problem_size_.R;
             r += problem_size_.stride_h) {
            CUTLASS_PRAGMA_UNROLL
            for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided;
                 ++s_idx) {
                int p = offset_p[s_idx];

                p += (params_.conv_sign * (r / problem_size_.stride_h));

                bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 &&
                             p < problem_size_.P);

                CUTLASS_PRAGMA_UNROLL
                for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
                    masks_[s_idx][v_idx][0] |= (pred << r);
                }
            }
        }

        CUTLASS_PRAGMA_NO_UNROLL
        for (int s = start_s; s < problem_size_.S;
             s += problem_size_.stride_w) {
            CUTLASS_PRAGMA_UNROLL
            for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided;
                 ++s_idx) {
                int q = offset_q[s_idx];
                q += (params_.conv_sign * (s / problem_size_.stride_w));

                bool pred = (q >= 0 && q < problem_size_.Q);

                CUTLASS_PRAGMA_UNROLL
                for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
                    masks_[s_idx][v_idx][1] |= (pred << s);
                }
            }
        }

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
            clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >=
                                      problem_size.K);
        }

        set_iteration_index(0);
    }

    CUTLASS_HOST_DEVICE
    static Params getParams(Conv2dProblemSize const& problem_size,
                            Layout const& layout) {
        return Params(problem_size, layout, sizeof_bits<Element>::value,
                      {Shape::kRow, Shape::kColumn});
    }

private:
    /// Adds a byte offset to each pointer, optionally subtracting a reset
    /// offset
    CUTLASS_HOST_DEVICE
    void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            pointer_[s] += byte_offset - byte_reset;
        }
    }

public:
    /// Overrides the internal iteration index
    CUTLASS_HOST_DEVICE
    void set_iteration_index(Index index) {
        iteration_vector_ = index % kAccessesPerVector;
        int residual_access = index / kAccessesPerVector;
        iteration_contiguous_ =
                residual_access % ThreadMap::Iterations::kContiguous;
        iteration_strided_ =
                residual_access / ThreadMap::Iterations::kContiguous;
    }

    /// Adds a pointer offset in units of Element
    CUTLASS_HOST_DEVICE
    void add_pointer_offset(LongIndex pointer_offset) {
        add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
    }

    CUTLASS_HOST_DEVICE
    void advance() {
        int next_idx = 0;
        int64_t reset_bytes = 0;

        // Move filter_s by stride_w
        filter_s_ += problem_size_.stride_w;
        if (filter_s_ >= problem_size_.S) {
            // Restore filter_s
            filter_s_ = start_s_;

            // Move filter_r by stride_h
            filter_r_ += problem_size_.stride_h;
            if (filter_r_ < problem_size_.R) {
                next_idx = 1;

                // Restore bytes in the q coordinate (MMA in the filter s
                // dimension)
                reset_bytes = reset_bytes_s_;

            } else {
                // Restore filter_r
                filter_r_ = start_r_;

                next_idx = 2;

                // Restore bytes in the p and q coordinates (MMA in the
                // filter s and r dimensions)
                reset_bytes = reset_bytes_r_;
            }
        }

        // offset pointers by offset_bytes
        add_byte_offset_(params_.inc_next[next_idx] - reset_bytes);

        if (next_idx == 2) {
            filter_k_ += params_.filter_k_delta;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
            clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >=
                                      problem_size_.K);
        }
    }

    /// Clears the predicates
    CUTLASS_HOST_DEVICE
    void clear_mask(bool clear = true) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < kAccessesPerVector; ++v) {
                masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
                masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
            }
        }
    }

    /// Clears the predicates
    CUTLASS_HOST_DEVICE
    void clear_mask(int v, bool clear = true) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
            masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
        }
    }

    /// Returns true if the current coordinate is within the output tensor Dy
    CUTLASS_HOST_DEVICE
    bool valid() const {
        return (masks_[iteration_strided_][iteration_vector_][0] &
                (Index(1) << filter_r_)) &&
               (masks_[iteration_strided_][iteration_vector_][1] &
                (Index(1) << filter_s_));
    }

    /// Returns a pointer to the vector starting at the current coordinate
    CUTLASS_HOST_DEVICE
    AccessType const* get() const {
        return reinterpret_cast<AccessType const*>(
                       pointer_[iteration_strided_]) +
               iteration_vector_;
    }

    /// Increments to the next memory access
    CUTLASS_HOST_DEVICE
    Conv2dDgradOutputGradientTileAccessIteratorOptimized& operator++() {
        ++iteration_vector_;
        if (iteration_vector_ < kAccessesPerVector) {
            return *this;
        }
        iteration_vector_ = 0;

        ++iteration_contiguous_;
        if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
            return *this;
        }
        iteration_contiguous_ = 0;
        ++iteration_strided_;
        if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
            return *this;
        }
        iteration_strided_ = 0;

        return *this;
    }

    /// Determines whether the Implicit GEMM can execute the given problem.
    CUTLASS_HOST_DEVICE
    static Status can_implement(Conv2dProblemSize const& problem_size) {
        // check alignment constraint on iterator's contiguous dimension
        if (problem_size.K % AccessType::kElements) {
            return Status::kErrorInvalidProblem;
        }

        // Limit on filter size: the predicate masks pack one bit per filter
        // position
        if (problem_size.R > 32 || problem_size.S > 32) {
            return Status::kErrorNotSupported;
        }

        return Status::kSuccess;
    }
};
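
// Host-side sketch (illustrative; Iterator stands for a concrete
// instantiation of the class above):
//
//   cutlass::Status status = Iterator::can_implement(problem_size);
//   if (status != cutlass::Status::kSuccess) {
//     // e.g. fall back to an iterator without the R/S <= 32 restriction
//   }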

/////////////////////////////////////////////////////////////////////////////////////////////////
// Conv2dDgradOutputGradientTileAccessIteratorOptimized for unity stride is
// optimized for dgrad problems with stride = {1, 1}
/////////////////////////////////////////////////////////////////////////////////////////////////
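
// Because StrideSupport defaults to kUnity in the forward declaration above,
// Conv2dDgradOutputGradientTileAccessIteratorOptimized<Shape, Element,
// ThreadMap> selects this specialization unless kStrided is requested
// explicitly.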

template <typename Shape_, typename Element_, typename ThreadMap_,
          typename AccessType_>
class Conv2dDgradOutputGradientTileAccessIteratorOptimized<
        Shape_, Element_, ThreadMap_, conv::StrideSupport::kUnity,
        AccessType_> {
public:
    //
    // Types
    //

    using Shape = Shape_;
    using Element = Element_;
    using Layout = layout::TensorNHWC;
    using TensorCoord = typename Layout::TensorCoord;
    using ThreadMap = ThreadMap_;
    using AccessType = AccessType_;
    using TensorRef = cutlass::TensorRef<Element, Layout>;
    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;
    static IteratorAlgorithm const kIteratorAlgorithm =
            conv::IteratorAlgorithm::kOptimized;
    static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
    static int const kConvDim = 2;
    using ConvProblemSize = typename conv::Conv2dProblemSize;

    static int const kAccessesPerVector =
            ThreadMap::kElementsPerAccess / AccessType::kElements;

    static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
                  "Vectors implied by the thread map must be divisible by the "
                  "access type.");

    using Mask = uint64_t;

    //
    // Simplifying assertions
    //
    static_assert(ThreadMap::Iterations::kContiguous == 1,
                  "Require Iterations::kContiguous == 1");

    //
    // Parameters structure
    //

    using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;

private:
    Params const& params_;
    Conv2dProblemSize const& problem_size_;
    LongIndex iteration_contiguous_;
    LongIndex iteration_strided_;
    LongIndex iteration_vector_;

    // One pointer per access
    char const* pointer_[ThreadMap::Iterations::kStrided];

    // current filter position (r, s)
    int filter_r_;
    int filter_s_;
    int filter_k_;

    Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];

public:
    CUTLASS_HOST_DEVICE
    Conv2dDgradOutputGradientTileAccessIteratorOptimized(
            Params const& params,
            Conv2dProblemSize const& problem_size, Element const* ptr,
            int thread_idx,
            MatrixCoord const& threadblock_offset =
                    MatrixCoord()  // tile index - units are threadblock-scoped
                                   // tiles
            )
            : params_(params),
              problem_size_(problem_size),
              filter_k_(0),
              filter_r_(0),
              filter_s_(0) {
        layout::PitchLinearCoord thread_coord =
                ThreadMap::initial_offset(thread_idx);

        filter_k_ = threadblock_offset.column() + thread_coord.contiguous();

        int offset_n[ThreadMap::Iterations::kStrided];
        int offset_h[ThreadMap::Iterations::kStrided];
        int offset_w[ThreadMap::Iterations::kStrided];

        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            pointer_[s] = reinterpret_cast<char const*>(ptr);

            int offset_nhw = threadblock_offset.row() + thread_coord.strided() +
                             s * ThreadMap::Delta::kStrided;

            // The subsequent fast_divmod() operations are equivalent to the
            // following logical computation:
            //
            //
            //  offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W);
            //  int residual = offset_nhw % (problem_size_.H * problem_size_.W);
            //
            //  offset_h[s] = residual / problem_size_.W;
            //  offset_w[s] = residual % problem_size_.W;
            //

            int residual;

            params_.hw_divmod(offset_n[s], residual, offset_nhw);
            params_.w_divmod(offset_h[s], offset_w[s], residual);

            TensorCoord coord =
                    at_(offset_n[s], offset_h[s], offset_w[s], 0, 0);

            pointer_[s] +=
                    params_.layout(coord) * sizeof_bits<Element>::value / 8;
        }

        clear_mask();

        CUTLASS_PRAGMA_NO_UNROLL
        for (int r = 0; r < problem_size_.R; ++r) {
            CUTLASS_PRAGMA_UNROLL
            for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided;
                 ++s_idx) {
                int r_ = r;
                if (problem_size_.mode == Mode::kConvolution) {
                    r_ = problem_size_.R - 1 - r;
                }

                int p = offset_h[s_idx] + problem_size_.pad_h -
                        r_ * problem_size_.dilation_h;

                bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 &&
                             p < problem_size_.P);

                CUTLASS_PRAGMA_UNROLL
                for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
                    masks_[s_idx][v_idx][0] |= (pred << r);
                }
            }
        }

        CUTLASS_PRAGMA_NO_UNROLL
        for (int s = 0; s < problem_size_.S; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided;
                 ++s_idx) {
                int s_ = s;
                if (problem_size_.mode == Mode::kConvolution) {
                    s_ = problem_size_.S - 1 - s;
                }

                int q = offset_w[s_idx] + problem_size_.pad_w -
                        s_ * problem_size_.dilation_w;

                bool pred = (q >= 0 && q < problem_size_.Q);

                CUTLASS_PRAGMA_UNROLL
                for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
                    masks_[s_idx][v_idx][1] |= (pred << s);
                }
            }
        }

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
            clear_mask(v_idx, filter_k_ >= problem_size.K);
        }

        set_iteration_index(0);
    }

    CUTLASS_HOST_DEVICE
    static Params getParams(Conv2dProblemSize const& problem_size,
                            Layout const& layout) {
        return Params(
                problem_size, layout, sizeof_bits<Element>::value,
                {Shape::kRow, Shape::kColumn}, ThreadMap::kThreads,
                ThreadMap::kElementsPerAccess,
                {ThreadMap::Iterations::kContiguous,
                 ThreadMap::Iterations::kStrided},
                {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided});
    }

private:
    /// Returns the coordinate in the output gradient tensor Dy that
    /// corresponds to activation position (n, h, w) and filter position
    /// (r, s); the k coordinate is taken from filter_k_. For unit stride,
    /// the forward relation h = p - pad_h + r * dilation_h inverts to
    /// p = h + pad_h - r * dilation_h (and similarly for q).
    CUTLASS_HOST_DEVICE
    TensorCoord at_(int n, int h, int w, int r, int s) const {
        if (problem_size_.mode == Mode::kConvolution) {
            r = problem_size_.R - 1 - r;
            s = problem_size_.S - 1 - s;
        }

        int p = h + problem_size_.pad_h - r * problem_size_.dilation_h;
        int q = w + problem_size_.pad_w - s * problem_size_.dilation_w;

        return TensorCoord(n, p, q, filter_k_);
    }

    /// Adds a byte offset to each pointer
    CUTLASS_HOST_DEVICE
    void add_byte_offset_(LongIndex byte_offset) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            pointer_[s] += byte_offset;
        }
    }

public:
    /// Overrides the internal iteration index
    CUTLASS_HOST_DEVICE
    void set_iteration_index(Index index) {
        iteration_vector_ = index % kAccessesPerVector;
        int residual_access = index / kAccessesPerVector;
        iteration_contiguous_ =
                residual_access % ThreadMap::Iterations::kContiguous;
        iteration_strided_ =
                residual_access / ThreadMap::Iterations::kContiguous;
    }

    /// Adds a pointer offset in units of element
    CUTLASS_HOST_DEVICE
    void add_pointer_offset(LongIndex pointer_offset) {
        add_byte_offset_(pointer_offset * sizeof_bits<Element>::value / 8);
    }

    CUTLASS_HOST_DEVICE
    void advance() {
        int next_idx = 0;
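
        // next_idx selects the precomputed byte increment in
        // params_.inc_next: 0 advances to the next filter column (s), 1
        // wraps s and advances the filter row (r), and 2 wraps both and
        // advances to the next tile of K.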

        // moves to the next tile
        ++filter_s_;
        if (filter_s_ == problem_size_.S) {
            filter_s_ = 0;
            ++filter_r_;

            if (filter_r_ < problem_size_.R) {
                next_idx = 1;
            } else {
                filter_r_ = 0;
                next_idx = 2;
            }
        }

        add_byte_offset_(params_.inc_next[next_idx]);

        if (next_idx == 2) {
            filter_k_ += params_.filter_k_delta;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
            clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >=
                                      problem_size_.K);
        }
    }

    /// Clears the predicates
    CUTLASS_HOST_DEVICE
    void clear_mask(bool clear = true) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < kAccessesPerVector; ++v) {
                masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
                masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
            }
        }
    }

    /// Clears the predicates
    CUTLASS_HOST_DEVICE
    void clear_mask(int v, bool clear = true) {
        CUTLASS_PRAGMA_UNROLL
        for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
            masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
            masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
        }
    }

    /// Returns true if the current coordinate is within the output tensor Dy
    CUTLASS_HOST_DEVICE
    bool valid() const {
        return (masks_[iteration_strided_][iteration_vector_][0] &
                (Index(1) << filter_r_)) &&
               (masks_[iteration_strided_][iteration_vector_][1] &
                (Index(1) << filter_s_));
    }

    /// Returns a pointer to the vector starting at the current coordinate
    CUTLASS_HOST_DEVICE
    AccessType const* get() const {
        return reinterpret_cast<AccessType const*>(
                       pointer_[iteration_strided_]) +
               iteration_vector_;
    }

    /// Increments to the next memory access
    CUTLASS_HOST_DEVICE
    Conv2dDgradOutputGradientTileAccessIteratorOptimized& operator++() {
        ++iteration_vector_;
        if (iteration_vector_ < kAccessesPerVector) {
            return *this;
        }
        iteration_vector_ = 0;

        ++iteration_contiguous_;
        if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
            return *this;
        }
        iteration_contiguous_ = 0;

        ++iteration_strided_;
        if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
            return *this;
        }
        iteration_strided_ = 0;

        return *this;
    }

    /// Determines whether the Implicit GEMM can execute the given problem.
    CUTLASS_HOST_DEVICE
    static Status can_implement(Conv2dProblemSize const& problem_size) {
        // This is specialized for unit stride
        if (problem_size.stride() != MatrixCoord({1, 1})) {
            return Status::kErrorNotSupported;
        }

        // check alignment constraint on iterator's contiguous dimension
        if (problem_size.K % AccessType::kElements) {
            return Status::kErrorNotSupported;
        }

        // Limit on filter size: the predicate masks pack one bit per filter
        // position
        if (problem_size.R > 32 || problem_size.S > 32) {
            return Status::kErrorNotSupported;
        }
        return Status::kSuccess;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace conv
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
